{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Gwll5yy8vJBmbHrUdIMF0zMCdG2n6kaR?usp=sharing)\n"
      ],
      "metadata": {
        "id": "73nSLsIP89Y9"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ee5e50BZmj_h"
      },
      "source": [
        "# Enhancing RAG with Contextual Retrieval\n",
        "\n",
        "This notebook adapts Anthropic's Contextual Embeddings guide to route its model calls through Portkey, which adds the following to the RAG pipeline:\n",
        "\n",
        "- Unified API Gateway: Seamlessly connects to Anthropic, Voyage AI, and Cohere through a single interface.\n",
        "- Simplified Management: Centralizes API key control and usage tracking across providers using Virtual Keys.\n",
        "- Enhanced Reliability: Enables dynamic routing and fallback mechanisms for improved availability.\n",
        "- Cost Optimization: Potential for reduced costs through intelligent request routing.\n",
        "\n",
        "[Link to Anthropic's GitHub for the original guide](https://github.com/anthropics/anthropic-cookbook/blob/main/skills/contextual-embeddings/guide.ipynb)\n",
        "\n",
        "This adaptation shows how Portkey's Unified AI Gateway can streamline advanced RAG workflows.\n",
        "\n",
        "Retrieval Augmented Generation (RAG) enables Claude to leverage your internal knowledge bases, codebases, or any other corpus of documents when providing a response. Enterprises are increasingly building RAG applications to improve workflows in customer support, Q&A over internal company documents, financial & legal analysis, code generation, and much more.\n",
        "\n",
        "In a [separate guide](https://github.com/anthropics/anthropic-cookbook/blob/main/skills/retrieval_augmented_generation/guide.ipynb), we walked through setting up a basic retrieval system, demonstrated how to evaluate its performance, and then outlined a few techniques to improve performance. In this guide, we present a technique for improving retrieval performance: Contextual Embeddings.\n",
        "\n",
        "In traditional RAG, documents are typically split into smaller chunks for efficient retrieval. While this approach works well for many applications, it can lead to problems when individual chunks lack sufficient context. Contextual Embeddings solve this problem by adding relevant context to each chunk before embedding. This method improves the quality of each embedded chunk, allowing for more accurate retrieval and thus better overall performance. Averaged across all data sources we tested, Contextual Embeddings reduced the top-20-chunk retrieval failure rate by 35%.\n",
        "\n",
        "The same chunk-specific context can also be used with BM25 search to further improve retrieval performance. We introduce this technique in the “Contextual BM25” section.\n",
        "\n",
        "In this guide, we'll demonstrate how to build and optimize a Contextual Retrieval system using a dataset of 9 codebases as our knowledge base. We'll walk through:\n",
        "\n",
        "1) Setting up a basic retrieval pipeline to establish a baseline for performance.\n",
        "\n",
        "2) Contextual Embeddings: what it is, why it works, and how prompt caching makes it practical for production use cases.\n",
        "\n",
        "3) Implementing Contextual Embeddings and demonstrating performance improvements.\n",
        "\n",
        "4) Contextual BM25: improving performance with *contextual* BM25 hybrid search.\n",
        "\n",
        "5) Improving performance with reranking.\n",
        "\n",
        "### Evaluation Metrics & Dataset:\n",
        "\n",
        "We use a pre-chunked dataset of 9 codebases, all of which have been chunked with a basic character-splitting mechanism. Our evaluation dataset contains 248 queries, each of which has one or more 'golden chunks'. We'll use a metric called Pass@k to evaluate performance. Pass@k checks whether the golden chunk for each query appears among the first k chunks retrieved. When a query has several golden chunks, its score is the fraction of them found. In this case, Contextual Embeddings improved Pass@10 performance from ~87% to ~95%.\n",
        "\n",
        "You can find the code files and their chunks in [data/codebase_chunks.json](https://github.com/anthropics/anthropic-cookbook/blob/main/skills/contextual-embeddings/data/codebase_chunks.json) and the evaluation dataset in [data/evaluation_set.jsonl](https://github.com/anthropics/anthropic-cookbook/blob/main/skills/contextual-embeddings/data/evaluation_set.jsonl).\n",
        "\n",
        "#### Additional Notes:\n",
        "\n",
        "Prompt caching is helpful in managing costs when using this retrieval method. This feature is currently available on Anthropic's 1P API, and is coming soon to our 3P partner environments in AWS Bedrock and GCP Vertex. We know that many of our customers leverage AWS Knowledge Bases and GCP Vertex AI APIs when building RAG solutions, and this method can be used on either platform with a bit of customization. Consider reaching out to Anthropic or your AWS/GCP account team for guidance on this!\n",
        "\n",
        "To make it easier to use this method on Bedrock, the AWS team has provided us with code that you can use to implement a Lambda function that adds context to each document. If you deploy this Lambda function, you can select it as a custom chunking option when configuring a [Bedrock Knowledge Base](https://docs.aws.amazon.com/bedrock/latest/userguide/knowledge-base-create.html). You can find this code in `contextual-rag-lambda-function`. The main lambda function code is in `lambda_function.py`.\n",
        "\n",
        "## Table of Contents\n",
        "\n",
        "1) Setup\n",
        "\n",
        "2) Basic RAG\n",
        "\n",
        "3) Contextual Embeddings\n",
        "\n",
        "4) Contextual BM25\n",
        "\n",
        "5) Reranking"
      ]
    },
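    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make Pass@k concrete, here is a minimal, self-contained sketch of how it could be computed from lists of retrieved chunk IDs. The helper name `pass_at_k` and the toy IDs are hypothetical and only for illustration. The notebook's own `evaluate_retrieval` function (in the Basic RAG section) compares chunk contents rather than IDs, but it averages per-query scores in the same way.\n",
        "\n",
        "```python\n",
        "from typing import List, Sequence\n",
        "\n",
        "def pass_at_k(retrieved: Sequence[Sequence[str]], golden: Sequence[Sequence[str]], k: int) -> float:\n",
        "    # Hypothetical helper: per query, score = fraction of its golden IDs found in the top-k results\n",
        "    scores: List[float] = []\n",
        "    for results, gold in zip(retrieved, golden):\n",
        "        top_k = set(results[:k])\n",
        "        found = sum(1 for g in gold if g in top_k)\n",
        "        scores.append(found / len(gold))\n",
        "    # Average over all queries and report as a percentage\n",
        "    return 100.0 * sum(scores) / len(scores)\n",
        "\n",
        "# Toy example: two queries with one golden chunk each\n",
        "retrieved = [['c3', 'c1', 'c7'], ['c9', 'c4', 'c2']]\n",
        "golden = [['c1'], ['c2']]\n",
        "print(pass_at_k(retrieved, golden, k=2))  # 50.0\n",
        "print(pass_at_k(retrieved, golden, k=3))  # 100.0\n",
        "```\n",
        "\n",
        "Raising k can only keep or increase the score, because a larger k adds results to the top-k set without removing any."
      ]
    },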
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZV4u-Qh0mj_j"
      },
      "source": [
        "## Setup\n",
        "\n",
        "We'll need a few libraries, including:\n",
        "\n",
        "1) `portkey-ai` - Portkey's SDK, used as the gateway for calls to the providers below\n",
        "\n",
        "2) `anthropic` - to interact with Claude\n",
        "\n",
        "3) `voyageai` - to generate high quality embeddings\n",
        "\n",
        "4) `cohere` - for reranking\n",
        "\n",
        "5) `elasticsearch` - for performant BM25 search\n",
        "\n",
        "6) `pandas`, `numpy`, `matplotlib`, and `scikit-learn` - for data manipulation and visualization\n",
        "\n",
        "You'll also need API keys from [Anthropic](https://www.anthropic.com/), [Voyage AI](https://www.voyageai.com/), [Cohere](https://cohere.com/rerank) and [Portkey](https://app.portkey.ai)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rDgao8kZmj_j"
      },
      "outputs": [],
      "source": [
        "!pip install portkey-ai anthropic voyageai cohere elasticsearch\n",
        "!pip install pandas numpy matplotlib scikit-learn tqdm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QFBEEul2mj_j"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "os.environ['PORTKEY_API_KEY'] = \"YOUR KEY HERE\"\n",
        "os.environ['VOYAGE_API_KEY'] = \"YOUR KEY HERE\"\n",
        "os.environ['ANTHROPIC_API_KEY'] = \"YOUR KEY HERE\"\n",
        "os.environ['COHERE_API_KEY'] = \"YOUR KEY HERE\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uYZuWLFXmj_j"
      },
      "outputs": [],
      "source": [
        "import anthropic\n",
        "\n",
        "client = anthropic.Anthropic(\n",
        "    # This is the default and can be omitted\n",
        "    api_key=os.getenv(\"ANTHROPIC_API_KEY\"),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6Cilm22emj_j"
      },
      "source": [
        "### Initialize a Vector DB Class\n",
        "\n",
        "In this example, we're using an in-memory vector DB, but for a production application, you may want to use a hosted solution.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ly0uWj25mj_j"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import pickle\n",
        "import json\n",
        "import numpy as np\n",
        "from portkey_ai import Portkey\n",
        "from typing import List, Dict, Any\n",
        "from tqdm import tqdm\n",
        "\n",
        "\n",
        "\n",
        "class VectorDB:\n",
        "    def __init__(self, name: str, api_key = None):\n",
        "        if api_key is None:\n",
        "            api_key = os.getenv(\"VOYAGE_API_KEY\")\n",
        "        self.client = Portkey(\n",
        "            api_key=os.getenv(\"PORTKEY_API_KEY\"),  # Read from the environment variable set above\n",
        "            provider=\"voyage\",\n",
        "            Authorization=api_key\n",
        "        )\n",
        "        self.name = name\n",
        "        self.embeddings = []\n",
        "        self.metadata = []\n",
        "        self.query_cache = {}\n",
        "        self.db_path = f\"./data/{name}/vector_db.pkl\"\n",
        "\n",
        "    def load_data(self, dataset: List[Dict[str, Any]]):\n",
        "        if self.embeddings and self.metadata:\n",
        "            print(\"Vector database is already loaded. Skipping data loading.\")\n",
        "            return\n",
        "        if os.path.exists(self.db_path):\n",
        "            print(\"Loading vector database from disk.\")\n",
        "            self.load_db()\n",
        "            return\n",
        "\n",
        "        texts_to_embed = []\n",
        "        metadata = []\n",
        "        total_chunks = sum(len(doc['chunks']) for doc in dataset)\n",
        "\n",
        "        with tqdm(total=total_chunks, desc=\"Processing chunks\") as pbar:\n",
        "            for doc in dataset:\n",
        "                for chunk in doc['chunks']:\n",
        "                    texts_to_embed.append(chunk['content'])\n",
        "                    metadata.append({\n",
        "                        'doc_id': doc['doc_id'],\n",
        "                        'original_uuid': doc['original_uuid'],\n",
        "                        'chunk_id': chunk['chunk_id'],\n",
        "                        'original_index': chunk['original_index'],\n",
        "                        'content': chunk['content']\n",
        "                    })\n",
        "                    pbar.update(1)\n",
        "\n",
        "        self._embed_and_store(texts_to_embed, metadata)\n",
        "        self.save_db()\n",
        "\n",
        "        print(f\"Vector database loaded and saved. Total chunks processed: {len(texts_to_embed)}\")\n",
        "\n",
        "    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):\n",
        "        batch_size = 128\n",
        "        with tqdm(total=len(texts), desc=\"Embedding chunks\") as pbar:\n",
        "            result = []\n",
        "            for i in range(0, len(texts), batch_size):\n",
        "                batch = texts[i : i + batch_size]\n",
        "                # Portkey exposes an OpenAI-compatible embeddings endpoint for Voyage\n",
        "                response = self.client.embeddings.create(input=batch, model=\"voyage-2\")\n",
        "                batch_result = [item.embedding for item in response.data]\n",
        "                result.extend(batch_result)\n",
        "                pbar.update(len(batch))\n",
        "\n",
        "        self.embeddings = result\n",
        "        self.metadata = data\n",
        "\n",
        "    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:\n",
        "        if query in self.query_cache:\n",
        "            query_embedding = self.query_cache[query]\n",
        "        else:\n",
        "            response = self.client.embeddings.create(input=[query], model=\"voyage-2\")\n",
        "            query_embedding = response.data[0].embedding\n",
        "            self.query_cache[query] = query_embedding\n",
        "\n",
        "        if not self.embeddings:\n",
        "            raise ValueError(\"No data loaded in the vector database.\")\n",
        "\n",
        "        similarities = np.dot(self.embeddings, query_embedding)\n",
        "        top_indices = np.argsort(similarities)[::-1][:k]\n",
        "\n",
        "        top_results = []\n",
        "        for idx in top_indices:\n",
        "            result = {\n",
        "                \"metadata\": self.metadata[idx],\n",
        "                \"similarity\": float(similarities[idx]),\n",
        "            }\n",
        "            top_results.append(result)\n",
        "\n",
        "        return top_results\n",
        "\n",
        "    def save_db(self):\n",
        "        data = {\n",
        "            \"embeddings\": self.embeddings,\n",
        "            \"metadata\": self.metadata,\n",
        "            \"query_cache\": json.dumps(self.query_cache),\n",
        "        }\n",
        "        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)\n",
        "        with open(self.db_path, \"wb\") as file:\n",
        "            pickle.dump(data, file)\n",
        "\n",
        "    def load_db(self):\n",
        "        if not os.path.exists(self.db_path):\n",
        "            raise ValueError(\"Vector database file not found. Use load_data to create a new database.\")\n",
        "        with open(self.db_path, \"rb\") as file:\n",
        "            data = pickle.load(file)\n",
        "        self.embeddings = data[\"embeddings\"]\n",
        "        self.metadata = data[\"metadata\"]\n",
        "        self.query_cache = json.loads(data[\"query_cache\"])\n",
        "\n",
        "    def validate_embedded_chunks(self):\n",
        "        unique_contents = set()\n",
        "        for meta in self.metadata:\n",
        "            unique_contents.add(meta['content'])\n",
        "\n",
        "        print(f\"Validation results:\")\n",
        "        print(f\"Total embedded chunks: {len(self.metadata)}\")\n",
        "        print(f\"Unique embedded contents: {len(unique_contents)}\")\n",
        "\n",
        "        if len(self.metadata) != len(unique_contents):\n",
        "            print(\"Warning: There may be duplicate chunks in the embedded data.\")\n",
        "        else:\n",
        "            print(\"All embedded chunks are unique.\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UBty8NvDmj_k"
      },
      "outputs": [],
      "source": [
        "# Load your transformed dataset\n",
        "with open('data/codebase_chunks.json', 'r') as f:\n",
        "    transformed_dataset = json.load(f)\n",
        "\n",
        "# Initialize the VectorDB\n",
        "base_db = VectorDB(\"base_db\")\n",
        "\n",
        "# Load and process the data\n",
        "base_db.load_data(transformed_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lZAltnTKmj_k"
      },
      "source": [
        "## Basic RAG\n",
        "\n",
        "To get started, we'll set up a basic RAG pipeline using a bare-bones approach, sometimes called 'Naive RAG'. A basic RAG pipeline includes the following 3 steps:\n",
        "\n",
        "1) Chunk documents (our dataset comes pre-chunked with a basic character splitter, as described above)\n",
        "\n",
        "2) Embed each chunk\n",
        "\n",
        "3) Use cosine similarity to retrieve the chunks most relevant to a query"
      ]
    },
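    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `VectorDB.search` method in the Setup section ranks chunks with `np.dot` rather than an explicit cosine formula. A dot product equals cosine similarity only for unit-length vectors. Voyage documents its embeddings as normalized to length 1, which is why the shortcut works here. Re-check this if you swap in another embedding model. The sketch below normalizes explicitly and uses toy 2-D vectors, not real embeddings. The helper name `cosine_top_k` is hypothetical.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def cosine_top_k(embeddings: np.ndarray, query: np.ndarray, k: int) -> np.ndarray:\n",
        "    # Normalize every row and the query so that a dot product equals cosine similarity\n",
        "    emb = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)\n",
        "    q = query / np.linalg.norm(query)\n",
        "    sims = emb @ q\n",
        "    # Indices of the k most similar rows, highest similarity first\n",
        "    return np.argsort(sims)[::-1][:k]\n",
        "\n",
        "# Toy 2-D 'embeddings' for three chunks (not real model output)\n",
        "chunks = np.array([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]])\n",
        "query = np.array([0.0, 2.0])\n",
        "print(cosine_top_k(chunks, query, k=2))  # [2 1]\n",
        "```\n",
        "\n",
        "If the vectors are already unit length, the normalization step does not change them, and the ranking matches the plain `np.dot` ranking used in `VectorDB.search`."
      ]
    },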
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OqvKyjXCmj_k"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "from typing import List, Dict, Any, Callable, Union\n",
        "from tqdm import tqdm\n",
        "\n",
        "def load_jsonl(file_path: str) -> List[Dict[str, Any]]:\n",
        "    \"\"\"Load JSONL file and return a list of dictionaries.\"\"\"\n",
        "    with open(file_path, 'r') as file:\n",
        "        return [json.loads(line) for line in file]\n",
        "\n",
        "def evaluate_retrieval(queries: List[Dict[str, Any]], retrieval_function: Callable, db, k: int = 20) -> Dict[str, float]:\n",
        "    total_score = 0\n",
        "    total_queries = len(queries)\n",
        "\n",
        "    for query_item in tqdm(queries, desc=\"Evaluating retrieval\"):\n",
        "        query = query_item['query']\n",
        "        golden_chunk_uuids = query_item['golden_chunk_uuids']\n",
        "\n",
        "        # Find all golden chunk contents\n",
        "        golden_contents = []\n",
        "        for doc_uuid, chunk_index in golden_chunk_uuids:\n",
        "            golden_doc = next((doc for doc in query_item['golden_documents'] if doc['uuid'] == doc_uuid), None)\n",
        "            if not golden_doc:\n",
        "                print(f\"Warning: Golden document not found for UUID {doc_uuid}\")\n",
        "                continue\n",
        "\n",
        "            golden_chunk = next((chunk for chunk in golden_doc['chunks'] if chunk['index'] == chunk_index), None)\n",
        "            if not golden_chunk:\n",
        "                print(f\"Warning: Golden chunk not found for index {chunk_index} in document {doc_uuid}\")\n",
        "                continue\n",
        "\n",
        "            golden_contents.append(golden_chunk['content'].strip())\n",
        "\n",
        "        if not golden_contents:\n",
        "            print(f\"Warning: No golden contents found for query: {query}\")\n",
        "            continue\n",
        "\n",
        "        retrieved_docs = retrieval_function(query, db, k=k)\n",
        "\n",
        "        # Count how many golden chunks are in the top k retrieved documents\n",
        "        chunks_found = 0\n",
        "        for golden_content in golden_contents:\n",
        "            for doc in retrieved_docs[:k]:\n",
        "                retrieved_content = doc['metadata'].get('original_content', doc['metadata'].get('content', '')).strip()\n",
        "                if retrieved_content == golden_content:\n",
        "                    chunks_found += 1\n",
        "                    break\n",
        "\n",
        "        query_score = chunks_found / len(golden_contents)\n",
        "        total_score += query_score\n",
        "\n",
        "    average_score = total_score / total_queries\n",
        "    pass_at_n = average_score * 100\n",
        "    return {\n",
        "        \"pass_at_n\": pass_at_n,\n",
        "        \"average_score\": average_score,\n",
        "        \"total_queries\": total_queries\n",
        "    }\n",
        "\n",
        "def retrieve_base(query: str, db, k: int = 20) -> List[Dict[str, Any]]:\n",
        "    \"\"\"\n",
        "    Retrieve relevant documents using either VectorDB or ContextualVectorDB.\n",
        "\n",
        "    :param query: The query string\n",
        "    :param db: The VectorDB or ContextualVectorDB instance\n",
        "    :param k: Number of top results to retrieve\n",
        "    :return: List of retrieved documents\n",
        "    \"\"\"\n",
        "    return db.search(query, k=k)\n",
        "\n",
        "def evaluate_db(db, original_jsonl_path: str, k):\n",
        "    # Load the original JSONL data for queries and ground truth\n",
        "    original_data = load_jsonl(original_jsonl_path)\n",
        "\n",
        "    # Evaluate retrieval\n",
        "    results = evaluate_retrieval(original_data, retrieve_base, db, k)\n",
        "    print(f\"Pass@{k}: {results['pass_at_n']:.2f}%\")\n",
        "    print(f\"Total Score: {results['average_score']}\")\n",
        "    print(f\"Total queries: {results['total_queries']}\")\n",
        "    return results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0bRfJ8vomj_k"
      },
      "outputs": [],
      "source": [
        "results5 = evaluate_db(base_db, 'data/evaluation_set.jsonl', 5)\n",
        "results10 = evaluate_db(base_db, 'data/evaluation_set.jsonl', 10)\n",
        "results20 = evaluate_db(base_db, 'data/evaluation_set.jsonl', 20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2YXZ9NVlmj_k"
      },
      "source": [
        "## Contextual Embeddings\n",
        "\n",
        "With basic RAG, each embedded chunk contains a potentially useful piece of information, but these chunks lack context. With Contextual Embeddings, we create a variation on the embedding itself by adding more context to each text chunk before embedding it. Specifically, we use Claude to write a concise context that situates the chunk within the overall document. For our codebases dataset, we pass the LLM both the chunk and the full file it came from, and the LLM produces that context. We then combine this context and the raw text chunk into a single text block before creating each embedding.\n",
        "\n",
        "### Additional Considerations: Cost and Latency\n",
        "\n",
        "The extra work we're doing to 'situate' each document happens only at ingestion time. It's a cost you'll pay once when you store each document, and periodically in the future if you have a knowledge base that updates over time. There are many approaches like HyDE (hypothetical document embeddings) that involve performing steps to improve the representation of the query before executing a search. These techniques have been shown to be moderately effective, but they add significant latency at runtime.\n",
        "\n",
        "Prompt caching also makes this much more cost effective. Creating contextual embeddings requires us to pass the same document to the model for every chunk we want to generate extra context for. With prompt caching, we can write the overall document to the cache once. Because our ingestion job runs in sequence, we can then read the document from the cache as we generate context for each chunk within it (information written to the cache has a 5-minute time to live). This means that the first time we pass a document to the model, we pay a bit more to write it to the cache. For each subsequent API call that contains that document, we receive a 90% discount on all input tokens read from the cache. Assuming 800-token chunks, 8k-token documents, 50 tokens of context instructions, and 100 tokens of context per chunk, the cost to generate contextualized chunks is $1.02 per million document tokens.\n",
        "\n",
        "When you load data into your ContextualVectorDB below, you'll see in logs just how big this impact is.\n",
        "\n",
        "Warning: some smaller embedding models have a fixed input token limit. Contextualizing a chunk makes it longer, so if you notice much worse performance from contextualized embeddings, the contextualized chunk is likely being truncated."
      ]
    },
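    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The caching arithmetic above can be sketched as simple token bookkeeping. The helper below is hypothetical and makes three assumptions. It assumes one API call per chunk, a document that splits evenly into chunks, and a cache that stays warm (within the 5-minute TTL) for the whole document. It counts tokens only and leaves out prices.\n",
        "\n",
        "```python\n",
        "from dataclasses import dataclass\n",
        "\n",
        "@dataclass\n",
        "class TokenUsage:\n",
        "    cache_write: int     # document tokens written to the cache (first call only)\n",
        "    cache_read: int      # document tokens served from the cache (every later call)\n",
        "    uncached_input: int  # chunk + instruction tokens sent on every call\n",
        "    output: int          # generated context tokens\n",
        "\n",
        "def contextualization_tokens(doc_tokens: int, chunk_tokens: int,\n",
        "                             instruction_tokens: int, context_tokens: int) -> TokenUsage:\n",
        "    # Hypothetical helper; assumes the document splits evenly into chunks of chunk_tokens\n",
        "    n_chunks = doc_tokens // chunk_tokens\n",
        "    return TokenUsage(\n",
        "        cache_write=doc_tokens,\n",
        "        cache_read=doc_tokens * (n_chunks - 1),\n",
        "        uncached_input=n_chunks * (chunk_tokens + instruction_tokens),\n",
        "        output=n_chunks * context_tokens,\n",
        "    )\n",
        "\n",
        "# The assumptions from the text: 8k-token documents, 800-token chunks,\n",
        "# 50 tokens of instructions and 100 tokens of generated context per chunk\n",
        "print(contextualization_tokens(8000, 800, 50, 100))\n",
        "# TokenUsage(cache_write=8000, cache_read=72000, uncached_input=8500, output=1000)\n",
        "```\n",
        "\n",
        "To get the per-document cost, multiply each field by its per-token rate. Cache writes cost a bit more than regular input tokens, and cache reads cost about 90% less. Check Anthropic's current pricing before relying on any figure."
      ]
    },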
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fik_hlGFmj_l"
      },
      "outputs": [],
      "source": [
        "DOCUMENT_CONTEXT_PROMPT = \"\"\"\n",
        "<document>\n",
        "{doc_content}\n",
        "</document>\n",
        "\"\"\"\n",
        "\n",
        "CHUNK_CONTEXT_PROMPT = \"\"\"\n",
        "Here is the chunk we want to situate within the whole document\n",
        "<chunk>\n",
        "{chunk_content}\n",
        "</chunk>\n",
        "\n",
        "Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.\n",
        "Answer only with the succinct context and nothing else.\n",
        "\"\"\"\n",
        "\n",
        "def situate_context(doc: str, chunk: str):\n",
        "    # Route the request to Claude through Portkey's gateway\n",
        "    portkey = Portkey(\n",
        "        api_key=os.getenv(\"PORTKEY_API_KEY\"),\n",
        "        virtual_key=\"ANTHROPIC_VIRTUAL_KEY\",  # Replace with your Anthropic virtual key from Portkey\n",
        "        anthropic_beta=\"prompt-caching-2024-07-31\"\n",
        "    )\n",
        "\n",
        "    response = portkey.beta.prompt_caching.messages.create(\n",
        "        model=\"claude-3-haiku-20240307\",\n",
        "        max_tokens=1024,\n",
        "        temperature=0.0,\n",
        "        messages=[\n",
        "            {\n",
        "                \"role\": \"user\",\n",
        "                \"content\": [\n",
        "                    {\n",
        "                        \"type\": \"text\",\n",
        "                        \"text\": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),\n",
        "                        \"cache_control\": {\"type\": \"ephemeral\"}\n",
        "                    },\n",
        "                    {\n",
        "                        \"type\": \"text\",\n",
        "                        \"text\": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),\n",
        "                    }\n",
        "                ]\n",
        "            }\n",
        "        ],\n",
        "        extra_headers={\"anthropic-beta\": \"prompt-caching-2024-07-31\"}\n",
        "    )\n",
        "    return response\n",
        "\n",
        "jsonl_data = load_jsonl('data/evaluation_set.jsonl')\n",
        "# Example usage\n",
        "doc_content = jsonl_data[0]['golden_documents'][0]['content']\n",
        "chunk_content = jsonl_data[0]['golden_chunks'][0]['content']\n",
        "\n",
        "response = situate_context(doc_content, chunk_content)\n",
        "print(f\"Situated context: {response.content[0].text}\")\n",
        "\n",
        "# Print cache performance metrics\n",
        "print(f\"Input tokens: {response.usage.input_tokens}\")\n",
        "print(f\"Output tokens: {response.usage.output_tokens}\")\n",
        "print(f\"Cache creation input tokens: {response.usage.cache_creation_input_tokens}\")\n",
        "print(f\"Cache read input tokens: {response.usage.cache_read_input_tokens}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B10gcZrfmj_l"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import pickle\n",
        "import json\n",
        "import numpy as np\n",
        "from typing import List, Dict, Any\n",
        "from tqdm import tqdm\n",
        "import threading\n",
        "import time\n",
        "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
        "\n",
        "class ContextualVectorDB:\n",
        "    def __init__(self, name: str, voyage_api_key=None, anthropic_api_key=None):\n",
        "        if voyage_api_key is None:\n",
        "            voyage_api_key = os.getenv(\"VOYAGE_API_KEY\")\n",
        "        if anthropic_api_key is None:\n",
        "            anthropic_api_key = os.getenv(\"ANTHROPIC_API_KEY\")\n",
        "        self.voyage_client = Portkey(\n",
        "            api_key=os.getenv(\"PORTKEY_API_KEY\"),  # Read from the environment variable set above\n",
        "            provider=\"voyage\",\n",
        "            Authorization=voyage_api_key\n",
        "        )\n",
        "        # Claude calls are also routed through Portkey\n",
        "        self.anthropic_client = Portkey(\n",
        "            api_key=os.getenv(\"PORTKEY_API_KEY\"),\n",
        "            provider=\"anthropic\",\n",
        "            Authorization=anthropic_api_key\n",
        "        )\n",
        "        self.name = name\n",
        "        self.embeddings = []\n",
        "        self.metadata = []\n",
        "        self.query_cache = {}\n",
        "        self.db_path = f\"./data/{name}/contextual_vector_db.pkl\"\n",
        "\n",
        "        self.token_counts = {\n",
        "            'input': 0,\n",
        "            'output': 0,\n",
        "            'cache_read': 0,\n",
        "            'cache_creation': 0\n",
        "        }\n",
        "        self.token_lock = threading.Lock()\n",
        "\n",
        "    def situate_context(self, doc: str, chunk: str) -> tuple[str, Any]:\n",
        "        DOCUMENT_CONTEXT_PROMPT = \"\"\"\n",
        "        <document>\n",
        "        {doc_content}\n",
        "        </document>\n",
        "        \"\"\"\n",
        "\n",
        "        CHUNK_CONTEXT_PROMPT = \"\"\"\n",
        "        Here is the chunk we want to situate within the whole document\n",
        "        <chunk>\n",
        "        {chunk_content}\n",
        "        </chunk>\n",
        "\n",
        "        Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.\n",
        "        Answer only with the succinct context and nothing else.\n",
        "        \"\"\"\n",
        "\n",
        "        response = self.anthropic_client.beta.prompt_caching.messages.create(\n",
        "            model=\"claude-3-haiku-20240307\",\n",
        "            max_tokens=1000,\n",
        "            temperature=0.0,\n",
        "            messages=[\n",
        "                {\n",
        "                    \"role\": \"user\",\n",
        "                    \"content\": [\n",
        "                        {\n",
        "                            \"type\": \"text\",\n",
        "                            \"text\": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),\n",
        "                            \"cache_control\": {\"type\": \"ephemeral\"}\n",
        "                        },\n",
        "                        {\n",
        "                            \"type\": \"text\",\n",
        "                            \"text\": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),\n",
        "                        },\n",
        "                    ]\n",
        "                },\n",
        "            ],\n",
        "            extra_headers={\"anthropic-beta\": \"prompt-caching-2024-07-31\"}\n",
        "        )\n",
        "        return response.content[0].text, response.usage\n",
        "\n",
        "    def load_data(self, dataset: List[Dict[str, Any]], parallel_threads: int = 1):\n",
        "        if self.embeddings and self.metadata:\n",
        "            print(\"Vector database is already loaded. Skipping data loading.\")\n",
        "            return\n",
        "        if os.path.exists(self.db_path):\n",
        "            print(\"Loading vector database from disk.\")\n",
        "            self.load_db()\n",
        "            return\n",
        "\n",
        "        texts_to_embed = []\n",
        "        metadata = []\n",
        "        total_chunks = sum(len(doc['chunks']) for doc in dataset)\n",
        "\n",
        "        def process_chunk(doc, chunk):\n",
        "            contextualized_text, usage = self.situate_context(doc['content'], chunk['content'])\n",
        "            with self.token_lock:\n",
        "                self.token_counts['input'] += usage.input_tokens\n",
        "                self.token_counts['output'] += usage.output_tokens\n",
        "                self.token_counts['cache_read'] += usage.cache_read_input_tokens\n",
        "                self.token_counts['cache_creation'] += usage.cache_creation_input_tokens\n",
        "\n",
        "            return {\n",
        "                'text_to_embed': f\"{chunk['content']}\\n\\n{contextualized_text}\",\n",
        "                'metadata': {\n",
        "                    'doc_id': doc['doc_id'],\n",
        "                    'original_uuid': doc['original_uuid'],\n",
        "                    'chunk_id': chunk['chunk_id'],\n",
        "                    'original_index': chunk['original_index'],\n",
        "                    'original_content': chunk['content'],\n",
        "                    'contextualized_content': contextualized_text\n",
        "                }\n",
        "            }\n",
        "\n",
        "        print(f\"Processing {total_chunks} chunks with {parallel_threads} threads\")\n",
        "        with ThreadPoolExecutor(max_workers=parallel_threads) as executor:\n",
        "            futures = []\n",
        "            for doc in dataset:\n",
        "                for chunk in doc['chunks']:\n",
        "                    futures.append(executor.submit(process_chunk, doc, chunk))\n",
        "\n",
        "            for future in tqdm(as_completed(futures), total=total_chunks, desc=\"Processing chunks\"):\n",
        "                result = future.result()\n",
        "                texts_to_embed.append(result['text_to_embed'])\n",
        "                metadata.append(result['metadata'])\n",
        "\n",
        "        self._embed_and_store(texts_to_embed, metadata)\n",
        "        self.save_db()\n",
        "\n",
        "        print(f\"Contextual Vector database loaded and saved. Total chunks processed: {len(texts_to_embed)}\")\n",
        "        print(f\"Total input tokens without caching: {self.token_counts['input']}\")\n",
        "        print(f\"Total output tokens: {self.token_counts['output']}\")\n",
        "        print(f\"Total input tokens written to cache: {self.token_counts['cache_creation']}\")\n",
        "        print(f\"Total input tokens read from cache: {self.token_counts['cache_read']}\")\n",
        "\n",
        "        total_tokens = self.token_counts['input'] + self.token_counts['cache_read'] + self.token_counts['cache_creation']\n",
        "        savings_percentage = (self.token_counts['cache_read'] / total_tokens) * 100 if total_tokens > 0 else 0\n",
        "        print(f\"Total input token savings from prompt caching: {savings_percentage:.2f}% of all input tokens used were read from cache.\")\n",
        "        print(\"Tokens read from cache come at a 90 percent discount!\")\n",
        "\n",
        "    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):\n",
        "        batch_size = 128\n",
        "        result = [\n",
        "            self.voyage_client.embed(\n",
        "                texts[i : i + batch_size],\n",
        "                model=\"voyage-2\"\n",
        "            ).embeddings\n",
        "            for i in range(0, len(texts), batch_size)\n",
        "        ]\n",
        "        self.embeddings = [embedding for batch in result for embedding in batch]\n",
        "        self.metadata = data\n",
        "\n",
        "    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:\n",
        "        if query in self.query_cache:\n",
        "            query_embedding = self.query_cache[query]\n",
        "        else:\n",
        "            query_embedding = self.voyage_client.embed([query], model=\"voyage-2\").embeddings[0]\n",
        "            self.query_cache[query] = query_embedding\n",
        "\n",
        "        if not self.embeddings:\n",
        "            raise ValueError(\"No data loaded in the vector database.\")\n",
        "\n",
        "        similarities = np.dot(self.embeddings, query_embedding)\n",
        "        top_indices = np.argsort(similarities)[::-1][:k]\n",
        "\n",
        "        top_results = []\n",
        "        for idx in top_indices:\n",
        "            result = {\n",
        "                \"metadata\": self.metadata[idx],\n",
        "                \"similarity\": float(similarities[idx]),\n",
        "            }\n",
        "            top_results.append(result)\n",
        "        return top_results\n",
        "\n",
        "    def save_db(self):\n",
        "        data = {\n",
        "            \"embeddings\": self.embeddings,\n",
        "            \"metadata\": self.metadata,\n",
        "            \"query_cache\": json.dumps(self.query_cache),\n",
        "        }\n",
        "        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)\n",
        "        with open(self.db_path, \"wb\") as file:\n",
        "            pickle.dump(data, file)\n",
        "\n",
        "    def load_db(self):\n",
        "        if not os.path.exists(self.db_path):\n",
        "            raise ValueError(\"Vector database file not found. Use load_data to create a new database.\")\n",
        "        with open(self.db_path, \"rb\") as file:\n",
        "            data = pickle.load(file)\n",
        "        self.embeddings = data[\"embeddings\"]\n",
        "        self.metadata = data[\"metadata\"]\n",
        "        self.query_cache = json.loads(data[\"query_cache\"])"
      ]
    },
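    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`load_data` submits every chunk to a thread pool and collects results with `as_completed`, which yields futures in the order they finish rather than the order they were submitted. The minimal sketch below shows why the texts and metadata still line up. It uses a hypothetical `fake_situate` function in place of the real `situate_context` API call. Each future returns both values together, so they are always appended as a pair.\n",
        "\n",
        "```python\n",
        "import random\n",
        "import time\n",
        "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
        "\n",
        "def fake_situate(chunk_id):\n",
        "    # Hypothetical stand-in for situate_context: no API call, just random latency\n",
        "    time.sleep(random.uniform(0, 0.01))\n",
        "    return {'text_to_embed': 'text-' + str(chunk_id), 'metadata': {'chunk_id': chunk_id}}\n",
        "\n",
        "texts, metadata = [], []\n",
        "with ThreadPoolExecutor(max_workers=5) as executor:\n",
        "    futures = [executor.submit(fake_situate, i) for i in range(20)]\n",
        "    # Futures are yielded as they finish, not in submission order\n",
        "    for future in as_completed(futures):\n",
        "        result = future.result()\n",
        "        texts.append(result['text_to_embed'])\n",
        "        metadata.append(result['metadata'])\n",
        "\n",
        "print([m['chunk_id'] for m in metadata])  # may not be in submission order\n",
        "# Each text still sits at the same position as its own metadata\n",
        "print(all(t == 'text-' + str(m['chunk_id']) for t, m in zip(texts, metadata)))  # True\n",
        "```\n",
        "\n",
        "`search` only needs `self.embeddings[i]` and `self.metadata[i]` to describe the same chunk, so the completion order itself does not matter."
      ]
    },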
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Tawxjiwmmj_l"
      },
      "outputs": [],
      "source": [
        "# Load the transformed dataset\n",
        "with open('data/codebase_chunks.json', 'r') as f:\n",
        "    transformed_dataset = json.load(f)\n",
        "\n",
        "# Initialize the ContextualVectorDB\n",
        "contextual_db = ContextualVectorDB(\"my_contextual_db\")\n",
        "\n",
        "# Load and process the data\n",
        "# Note: increase parallel_threads to run this faster, or decrease it if you are concerned about hitting your API rate limits\n",
        "contextual_db.load_data(transformed_dataset, parallel_threads=5)"
      ]
    },
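    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The summary printed by `load_data` reports cache reads as a share of all input tokens. The sketch below reproduces that calculation and adds a rough relative-cost estimate. The 1.25× multiplier for cache writes and the 0.1× multiplier for cache reads are assumptions based on Anthropic's published prompt-caching pricing at the time of writing. The token counts are made up for illustration.\n",
        "\n",
        "```python\n",
        "def cache_read_share(input_tokens, cache_read, cache_creation):\n",
        "    # Share of all input tokens served from cache, as computed in load_data\n",
        "    total = input_tokens + cache_read + cache_creation\n",
        "    return cache_read / total * 100 if total > 0 else 0.0\n",
        "\n",
        "def relative_input_cost(input_tokens, cache_read, cache_creation,\n",
        "                        write_multiplier=1.25, read_multiplier=0.1):\n",
        "    # Input cost in uncached-token equivalents; the multipliers are assumed prices\n",
        "    return input_tokens + cache_creation * write_multiplier + cache_read * read_multiplier\n",
        "\n",
        "# Hypothetical token counts, for illustration only\n",
        "counts = {'input': 1000, 'cache_read': 18000, 'cache_creation': 2000}\n",
        "share = cache_read_share(counts['input'], counts['cache_read'], counts['cache_creation'])\n",
        "cost = relative_input_cost(counts['input'], counts['cache_read'], counts['cache_creation'])\n",
        "no_cache_cost = sum(counts.values())  # every input token billed at the base rate\n",
        "print(f'{share:.2f}% of input tokens were read from cache')  # 85.71%\n",
        "print(f'Input cost relative to no caching: {cost / no_cache_cost:.2%}')  # 25.24%\n",
        "```\n",
        "\n",
        "Cache writes are priced above the base rate. Even so, reading most of each document prefix from cache cuts the estimated input cost to roughly a quarter in this example."
      ]
    },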
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Lp1NJzoNmj_l"
      },
      "outputs": [],
      "source": [
        "r5 = evaluate_db(contextual_db, 'data/evaluation_set.jsonl', 5)\n",
        "r10 = evaluate_db(contextual_db, 'data/evaluation_set.jsonl', 10)\n",
        "r20 = evaluate_db(contextual_db, 'data/evaluation_set.jsonl', 20)"
      ]
    },
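    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The evaluation cells report Pass@k. For each query, this is the fraction of its golden chunks that appear among the top k retrieved chunks. That fraction is averaged over all queries and expressed as a percentage. Below is a minimal sketch of that calculation on hypothetical toy data. As in the evaluation code, it compares chunks by their stripped text.\n",
        "\n",
        "```python\n",
        "def query_score(golden_contents, retrieved_contents, k):\n",
        "    # Fraction of golden chunks found among the top-k retrieved chunks\n",
        "    top_k = {c.strip() for c in retrieved_contents[:k]}\n",
        "    found = sum(1 for g in golden_contents if g.strip() in top_k)\n",
        "    return found / len(golden_contents)\n",
        "\n",
        "def pass_at_k(per_query, k):\n",
        "    # per_query: list of (golden_contents, retrieved_contents) pairs, best result first\n",
        "    scores = [query_score(golden, retrieved, k) for golden, retrieved in per_query]\n",
        "    return 100 * sum(scores) / len(scores)\n",
        "\n",
        "# Hypothetical toy data for two queries\n",
        "toy = [\n",
        "    (['a', 'b'], ['x', 'a', 'y', 'b']),  # golden chunks at ranks 2 and 4\n",
        "    (['c'], ['x', 'y', 'z', 'c']),       # golden chunk at rank 4\n",
        "]\n",
        "print(pass_at_k(toy, k=2))  # 25.0\n",
        "print(pass_at_k(toy, k=4))  # 100.0\n",
        "```\n",
        "\n",
        "As in the notebook's evaluation code, a query with two golden chunks earns partial credit (0.5) when only one of them is retrieved."
      ]
    },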
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "blvJY5rZmj_l"
      },
      "source": [
        "## Contextual BM25\n",
        "\n",
        "Contextual embeddings are an improvement on traditional semantic-search RAG, but we can push performance further. In this section we show how to use contextual embeddings and *contextual* BM25 together. Pairing these techniques already improves performance without the added context. Adding context to both reduces the top-20-chunk retrieval failure rate by 42%.\n",
        "\n",
        "BM25 is a probabilistic ranking function that improves upon TF-IDF. It scores documents based on query term frequency while accounting for document length and term saturation. It is widely used in modern search engines for its effectiveness at ranking relevant documents. For more details, see [this blog post](https://www.elastic.co/blog/practical-bm25-part-2-the-bm25-algorithm-and-its-variables). We'll use Elasticsearch for the BM25 portion of this section, which requires the `elasticsearch` Python library and an Elasticsearch server running in the background. The easiest way to start one is to install [Docker](https://docs.docker.com/engine/install/) and run the following command:\n",
        "\n",
        "`docker run -d --name elasticsearch -p 9200:9200 -p 9300:9300 -e \"discovery.type=single-node\" -e \"xpack.security.enabled=false\" elasticsearch:8.8.0`\n",
        "\n",
        "This section differs from a typical BM25 search in one way. We'll run each BM25 search over both the original chunk content and the context we generated for that chunk in the previous section. We'll then use a technique called reciprocal rank fusion to merge the BM25 results with our semantic search results. This hybrid search across our BM25 corpus and vector DB lets us return the most relevant chunks for a given query.\n",
        "\n",
        "In the function below, you can assign weights to the semantic search and BM25 results as they are merged with a weighted form of reciprocal rank fusion. By default, we use 0.8 for the semantic search results and 0.2 for the BM25 results. We encourage you to experiment with different values here."
      ]
    },
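    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the BM25 description concrete, here is a small pure-Python sketch of BM25 scoring over an in-memory corpus. It uses the Lucene-style IDF, log(1 + (N - n + 0.5) / (n + 0.5)), with k1 = 1.2 and b = 0.75, which are Elasticsearch's defaults. It tokenizes with a naive lowercase split. The Elasticsearch index below uses the `english` analyzer instead (stemming, stop words), so scores will not match Elasticsearch's exactly. Recent Lucene versions also drop the constant (k1 + 1) factor, which rescales every score equally and leaves the ranking unchanged.\n",
        "\n",
        "```python\n",
        "import math\n",
        "from collections import Counter\n",
        "\n",
        "def bm25_scores(query, docs, k1=1.2, b=0.75):\n",
        "    # Naive tokenization; Elasticsearch's english analyzer also stems and drops stop words\n",
        "    tokenized = [d.lower().split() for d in docs]\n",
        "    n_docs = len(tokenized)\n",
        "    avg_len = sum(len(t) for t in tokenized) / n_docs\n",
        "    # Document frequency: how many documents contain each term\n",
        "    df = Counter(term for t in tokenized for term in set(t))\n",
        "    scores = []\n",
        "    for tokens in tokenized:\n",
        "        tf = Counter(tokens)\n",
        "        score = 0.0\n",
        "        for term in query.lower().split():\n",
        "            if tf[term] == 0:\n",
        "                continue\n",
        "            # Rarer terms get a larger IDF weight\n",
        "            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))\n",
        "            # k1 caps the benefit of repeated terms; b normalizes by document length\n",
        "            length_norm = 1 - b + b * len(tokens) / avg_len\n",
        "            score += idf * tf[term] * (k1 + 1) / (tf[term] + k1 * length_norm)\n",
        "        scores.append(score)\n",
        "    return scores\n",
        "\n",
        "docs = [\n",
        "    'parse_config reads the config file',\n",
        "    'the logger writes messages to stdout',\n",
        "    'config values set in parse_config override defaults',\n",
        "]\n",
        "# The shorter first doc outranks the longer third one; the logger doc scores 0.0\n",
        "print(bm25_scores('parse_config', docs))\n",
        "```"
      ]
    },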
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QAMh6LlSmj_l"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import json\n",
        "from typing import List, Dict, Any\n",
        "from tqdm import tqdm\n",
        "from elasticsearch import Elasticsearch\n",
        "from elasticsearch.helpers import bulk\n",
        "\n",
        "class ElasticsearchBM25:\n",
        "    def __init__(self, index_name: str = \"contextual_bm25_index\"):\n",
        "        self.es_client = Elasticsearch(\"http://localhost:9200\")\n",
        "        self.index_name = index_name\n",
        "        self.create_index()\n",
        "\n",
        "    def create_index(self):\n",
        "        index_settings = {\n",
        "            \"settings\": {\n",
        "                \"analysis\": {\"analyzer\": {\"default\": {\"type\": \"english\"}}},\n",
        "                \"similarity\": {\"default\": {\"type\": \"BM25\"}},\n",
        "                \"index.queries.cache.enabled\": False  # Disable query cache\n",
        "            },\n",
        "            \"mappings\": {\n",
        "                \"properties\": {\n",
        "                    \"content\": {\"type\": \"text\", \"analyzer\": \"english\"},\n",
        "                    \"contextualized_content\": {\"type\": \"text\", \"analyzer\": \"english\"},\n",
        "                    \"doc_id\": {\"type\": \"keyword\", \"index\": False},\n",
        "                    \"chunk_id\": {\"type\": \"keyword\", \"index\": False},\n",
        "                    \"original_index\": {\"type\": \"integer\", \"index\": False},\n",
        "                }\n",
        "            },\n",
        "        }\n",
        "        if not self.es_client.indices.exists(index=self.index_name):\n",
        "            self.es_client.indices.create(index=self.index_name, body=index_settings)\n",
        "            print(f\"Created index: {self.index_name}\")\n",
        "\n",
        "    def index_documents(self, documents: List[Dict[str, Any]]):\n",
        "        actions = [\n",
        "            {\n",
        "                \"_index\": self.index_name,\n",
        "                \"_source\": {\n",
        "                    \"content\": doc[\"original_content\"],\n",
        "                    \"contextualized_content\": doc[\"contextualized_content\"],\n",
        "                    \"doc_id\": doc[\"doc_id\"],\n",
        "                    \"chunk_id\": doc[\"chunk_id\"],\n",
        "                    \"original_index\": doc[\"original_index\"],\n",
        "                },\n",
        "            }\n",
        "            for doc in documents\n",
        "        ]\n",
        "        success, _ = bulk(self.es_client, actions)\n",
        "        self.es_client.indices.refresh(index=self.index_name)\n",
        "        return success\n",
        "\n",
        "    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:\n",
        "        self.es_client.indices.refresh(index=self.index_name)  # Force refresh before each search\n",
        "        search_body = {\n",
        "            \"query\": {\n",
        "                \"multi_match\": {\n",
        "                    \"query\": query,\n",
        "                    \"fields\": [\"content\", \"contextualized_content\"],\n",
        "                }\n",
        "            },\n",
        "            \"size\": k,\n",
        "        }\n",
        "        response = self.es_client.search(index=self.index_name, body=search_body)\n",
        "        return [\n",
        "            {\n",
        "                \"doc_id\": hit[\"_source\"][\"doc_id\"],\n",
        "                \"original_index\": hit[\"_source\"][\"original_index\"],\n",
        "                \"content\": hit[\"_source\"][\"content\"],\n",
        "                \"contextualized_content\": hit[\"_source\"][\"contextualized_content\"],\n",
        "                \"score\": hit[\"_score\"],\n",
        "            }\n",
        "            for hit in response[\"hits\"][\"hits\"]\n",
        "        ]\n",
        "\n",
        "def create_elasticsearch_bm25_index(db: ContextualVectorDB):\n",
        "    es_bm25 = ElasticsearchBM25()\n",
        "    es_bm25.index_documents(db.metadata)\n",
        "    return es_bm25\n",
        "\n",
        "def retrieve_advanced(query: str, db: ContextualVectorDB, es_bm25: ElasticsearchBM25, k: int, semantic_weight: float = 0.8, bm25_weight: float = 0.2):\n",
        "    num_chunks_to_recall = 150\n",
        "\n",
        "    # Semantic search\n",
        "    semantic_results = db.search(query, k=num_chunks_to_recall)\n",
        "    ranked_chunk_ids = [(result['metadata']['doc_id'], result['metadata']['original_index']) for result in semantic_results]\n",
        "\n",
        "    # BM25 search using Elasticsearch\n",
        "    bm25_results = es_bm25.search(query, k=num_chunks_to_recall)\n",
        "    ranked_bm25_chunk_ids = [(result['doc_id'], result['original_index']) for result in bm25_results]\n",
        "\n",
        "    # Combine results\n",
        "    chunk_ids = list(set(ranked_chunk_ids + ranked_bm25_chunk_ids))\n",
        "    chunk_id_to_score = {}\n",
        "\n",
        "    # Initial scoring with weights\n",
        "    for chunk_id in chunk_ids:\n",
        "        score = 0\n",
        "        if chunk_id in ranked_chunk_ids:\n",
        "            index = ranked_chunk_ids.index(chunk_id)\n",
        "            score += semantic_weight * (1 / (index + 1))  # Weighted 1/n scoring for semantic\n",
        "        if chunk_id in ranked_bm25_chunk_ids:\n",
        "            index = ranked_bm25_chunk_ids.index(chunk_id)\n",
        "            score += bm25_weight * (1 / (index + 1))  # Weighted 1/n scoring for BM25\n",
        "        chunk_id_to_score[chunk_id] = score\n",
        "\n",
        "    # Sort chunk IDs by their scores in descending order\n",
        "    sorted_chunk_ids = sorted(\n",
        "        chunk_id_to_score.keys(), key=lambda x: (chunk_id_to_score[x], x[0], x[1]), reverse=True\n",
        "    )\n",
        "\n",
        "    # Assign new scores based on the sorted order\n",
        "    for index, chunk_id in enumerate(sorted_chunk_ids):\n",
        "        chunk_id_to_score[chunk_id] = 1 / (index + 1)\n",
        "\n",
        "    # Prepare the final results\n",
        "    final_results = []\n",
        "    semantic_count = 0\n",
        "    bm25_count = 0\n",
        "    for chunk_id in sorted_chunk_ids[:k]:\n",
        "        chunk_metadata = next(chunk for chunk in db.metadata if chunk['doc_id'] == chunk_id[0] and chunk['original_index'] == chunk_id[1])\n",
        "        is_from_semantic = chunk_id in ranked_chunk_ids\n",
        "        is_from_bm25 = chunk_id in ranked_bm25_chunk_ids\n",
        "        final_results.append({\n",
        "            'chunk': chunk_metadata,\n",
        "            'score': chunk_id_to_score[chunk_id],\n",
        "            'from_semantic': is_from_semantic,\n",
        "            'from_bm25': is_from_bm25\n",
        "        })\n",
        "\n",
        "        if is_from_semantic and not is_from_bm25:\n",
        "            semantic_count += 1\n",
        "        elif is_from_bm25 and not is_from_semantic:\n",
        "            bm25_count += 1\n",
        "        else:  # it's in both\n",
        "            semantic_count += 0.5\n",
        "            bm25_count += 0.5\n",
        "\n",
        "    return final_results, semantic_count, bm25_count\n",
        "\n",
        "def load_jsonl(file_path: str) -> List[Dict[str, Any]]:\n",
        "    with open(file_path, 'r') as file:\n",
        "        return [json.loads(line) for line in file]\n",
        "\n",
        "def evaluate_db_advanced(db: ContextualVectorDB, original_jsonl_path: str, k: int):\n",
        "    original_data = load_jsonl(original_jsonl_path)\n",
        "    es_bm25 = create_elasticsearch_bm25_index(db)\n",
        "\n",
        "    try:\n",
        "        # Warm-up queries\n",
        "        warm_up_queries = original_data[:10]\n",
        "        for query_item in warm_up_queries:\n",
        "            _ = retrieve_advanced(query_item['query'], db, es_bm25, k)\n",
        "\n",
        "        total_score = 0\n",
        "        total_semantic_count = 0\n",
        "        total_bm25_count = 0\n",
        "        total_results = 0\n",
        "\n",
        "        for query_item in tqdm(original_data, desc=\"Evaluating retrieval\"):\n",
        "            query = query_item['query']\n",
        "            golden_chunk_uuids = query_item['golden_chunk_uuids']\n",
        "\n",
        "            golden_contents = []\n",
        "            for doc_uuid, chunk_index in golden_chunk_uuids:\n",
        "                golden_doc = next((doc for doc in query_item['golden_documents'] if doc['uuid'] == doc_uuid), None)\n",
        "                if golden_doc:\n",
        "                    golden_chunk = next((chunk for chunk in golden_doc['chunks'] if chunk['index'] == chunk_index), None)\n",
        "                    if golden_chunk:\n",
        "                        golden_contents.append(golden_chunk['content'].strip())\n",
        "\n",
        "            if not golden_contents:\n",
        "                print(f\"Warning: No golden contents found for query: {query}\")\n",
        "                continue\n",
        "\n",
        "            retrieved_docs, semantic_count, bm25_count = retrieve_advanced(query, db, es_bm25, k)\n",
        "\n",
        "            chunks_found = 0\n",
        "            for golden_content in golden_contents:\n",
        "                for doc in retrieved_docs[:k]:\n",
        "                    retrieved_content = doc['chunk']['original_content'].strip()\n",
        "                    if retrieved_content == golden_content:\n",
        "                        chunks_found += 1\n",
        "                        break\n",
        "\n",
        "            query_score = chunks_found / len(golden_contents)\n",
        "            total_score += query_score\n",
        "\n",
        "            total_semantic_count += semantic_count\n",
        "            total_bm25_count += bm25_count\n",
        "            total_results += len(retrieved_docs)\n",
        "\n",
        "        total_queries = len(original_data)\n",
        "        average_score = total_score / total_queries\n",
        "        pass_at_n = average_score * 100\n",
        "\n",
        "        semantic_percentage = (total_semantic_count / total_results) * 100 if total_results > 0 else 0\n",
        "        bm25_percentage = (total_bm25_count / total_results) * 100 if total_results > 0 else 0\n",
        "\n",
        "        results = {\n",
        "            \"pass_at_n\": pass_at_n,\n",
        "            \"average_score\": average_score,\n",
        "            \"total_queries\": total_queries\n",
        "        }\n",
        "\n",
        "        print(f\"Pass@{k}: {pass_at_n:.2f}%\")\n",
        "        print(f\"Average Score: {average_score:.2f}\")\n",
        "        print(f\"Total queries: {total_queries}\")\n",
        "        print(f\"Percentage of results from semantic search: {semantic_percentage:.2f}%\")\n",
        "        print(f\"Percentage of results from BM25: {bm25_percentage:.2f}%\")\n",
        "\n",
        "        return results, {\"semantic\": semantic_percentage, \"bm25\": bm25_percentage}\n",
        "\n",
        "    finally:\n",
        "        # Delete the Elasticsearch index\n",
        "        if es_bm25.es_client.indices.exists(index=es_bm25.index_name):\n",
        "            es_bm25.es_client.indices.delete(index=es_bm25.index_name)\n",
        "            print(f\"Deleted Elasticsearch index: {es_bm25.index_name}\")"
      ]
    },
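    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`retrieve_advanced` fuses the two ranked lists by giving each chunk `weight * 1 / rank` from every list it appears in (rank is 1-based), then sorting by the total. The sketch below isolates that step on hypothetical chunk ids. It adds an optional constant `c` to the denominator. With `c=0` it reproduces the notebook's scoring. The commonly cited form of reciprocal rank fusion uses equal weights and `c=60`, which flattens the gap between the top ranks. One simplification: ties here keep first-seen order, whereas `retrieve_advanced` breaks ties by `(doc_id, original_index)`.\n",
        "\n",
        "```python\n",
        "def weighted_rank_fusion(semantic_ids, bm25_ids, semantic_weight=0.8, bm25_weight=0.2, c=0):\n",
        "    # Each list adds weight / (rank + c) for every chunk it returns (rank is 1-based)\n",
        "    scores = {}\n",
        "    for weight, ranked in ((semantic_weight, semantic_ids), (bm25_weight, bm25_ids)):\n",
        "        for rank, chunk_id in enumerate(ranked, start=1):\n",
        "            scores[chunk_id] = scores.get(chunk_id, 0.0) + weight / (rank + c)\n",
        "    # Highest fused score first; sorted() is stable, so ties keep first-seen order\n",
        "    return sorted(scores, key=scores.get, reverse=True)\n",
        "\n",
        "semantic = ['A', 'B', 'C']  # hypothetical chunk ids, best first\n",
        "bm25 = ['C', 'D']\n",
        "print(weighted_rank_fusion(semantic, bm25))  # ['A', 'C', 'B', 'D']\n",
        "print(weighted_rank_fusion(semantic, bm25, 1.0, 1.0, c=60))  # ['C', 'A', 'B', 'D']\n",
        "```\n",
        "\n",
        "Chunk C appears in both lists. It moves to the top once equal weights and `c=60` dampen the rank-1 advantage that A has in semantic search."
      ]
    },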
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lOrataaymj_l"
      },
      "outputs": [],
      "source": [
        "results5 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 5)\n",
        "results10 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 10)\n",
        "results20 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O9SrMxsemj_m"
      },
      "source": [
        "## Adding a Re-Ranking Step\n",
        "\n",
        "If you want to improve performance further, we recommend adding a re-ranking step. With a re-ranker, you can retrieve more documents initially from your vector store, then use the re-ranker to select a subset of them. One common technique is to use re-ranking to implement high-precision hybrid search. You combine semantic search and keyword-based search in your initial retrieval step (as we did earlier in this guide). A re-ranking step then chooses only the k most relevant documents from the combined list returned by both systems.\n",
        "\n",
        "Below, we demonstrate only the re-ranking step (skipping the hybrid search technique for now). We retrieve 10× as many documents as the final k we want. We then use a re-ranking model from Cohere, called through Portkey, to select the k most relevant results from that list. Adding the re-ranking step delivers a modest additional gain in performance: in our case, Pass@10 improves from 92.81% to 94.79%."
      ]
    },
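    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The re-ranking pattern can be sketched independently of any provider. First, over-fetch candidates from the first-stage retriever and score each (query, document) pair. Then keep the top k and map each result back to its candidate by index, as `retrieve_rerank` does with the `index` field of each Cohere result. Here `toy_score` and `fake_search` are hypothetical stand-ins (word overlap and a fixed list), not the Cohere model or the vector DB.\n",
        "\n",
        "```python\n",
        "def rerank(query, candidates, k, score_fn):\n",
        "    # Score every (query, candidate) pair and keep the k best, remembering each candidate's index\n",
        "    scored = sorted(((score_fn(query, doc), i) for i, doc in enumerate(candidates)),\n",
        "                    key=lambda pair: pair[0], reverse=True)\n",
        "    return [{'index': i, 'relevance_score': s} for s, i in scored[:k]]\n",
        "\n",
        "def toy_score(query, doc):\n",
        "    # Hypothetical reranker: number of distinct query words that appear in the document\n",
        "    return len(set(query.lower().split()) & set(doc.lower().split()))\n",
        "\n",
        "def retrieve_then_rerank(query, search_fn, k, score_fn, overfetch=10):\n",
        "    # Over-fetch k * overfetch candidates from the first stage, then rerank down to k\n",
        "    candidates = search_fn(query, k * overfetch)\n",
        "    return [dict(r, document=candidates[r['index']]) for r in rerank(query, candidates, k, score_fn)]\n",
        "\n",
        "corpus = ['parse the config file', 'write logs to disk', 'load config flags', 'unrelated text']\n",
        "\n",
        "def fake_search(query, n):\n",
        "    # Hypothetical first-stage retriever that just returns the first n documents\n",
        "    return corpus[:n]\n",
        "\n",
        "top = retrieve_then_rerank('parse config', fake_search, k=2, score_fn=toy_score)\n",
        "print([(r['index'], r['relevance_score']) for r in top])  # [(0, 2), (2, 1)]\n",
        "```"
      ]
    },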
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e9n9KepUmj_m"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import json\n",
        "import time\n",
        "from typing import List, Dict, Any, Callable\n",
        "from tqdm import tqdm\n",
        "from portkey_ai import Portkey\n",
        "\n",
        "# Create the Portkey client once and reuse it for every rerank request.\n",
        "# Portkey forwards the Cohere API key to Cohere as a bearer token.\n",
        "portkey = Portkey(\n",
        "    api_key=\"PORTKEY_API_KEY\",  # Replace with your Portkey API key\n",
        "    provider=\"cohere\",\n",
        "    Authorization=f\"Bearer {os.getenv('COHERE_API_KEY')}\"\n",
        ")\n",
        "\n",
        "def load_jsonl(file_path: str) -> List[Dict[str, Any]]:\n",
        "    with open(file_path, 'r') as file:\n",
        "        return [json.loads(line) for line in file]\n",
        "\n",
        "def chunk_to_content(chunk: Dict[str, Any]) -> str:\n",
        "    # Combine the original chunk text with its generated context for reranking\n",
        "    original_content = chunk['metadata']['original_content']\n",
        "    contextualized_content = chunk['metadata']['contextualized_content']\n",
        "    return f\"{original_content}\\n\\nContext: {contextualized_content}\"\n",
        "\n",
        "def retrieve_rerank(query: str, db, k: int) -> List[Dict[str, Any]]:\n",
        "    # Retrieve 10x more candidates than the k results we will return\n",
        "    semantic_results = db.search(query, k=k*10)\n",
        "\n",
        "    # Extract documents for reranking, using the contextualized content\n",
        "    documents = [chunk_to_content(res) for res in semantic_results]\n",
        "\n",
        "    # Call Cohere's rerank endpoint through the Portkey gateway\n",
        "    response = portkey.post(\n",
        "        \"/rerank\",\n",
        "        model=\"rerank-english-v3.0\",\n",
        "        query=query,\n",
        "        documents=documents,\n",
        "        top_n=k\n",
        "    )\n",
        "    time.sleep(0.1)  # Short pause between requests to ease rate limiting\n",
        "\n",
        "    final_results = []\n",
        "    for r in response.results:\n",
        "        # Depending on the SDK version, each result may be a dict or an object\n",
        "        index = r[\"index\"] if isinstance(r, dict) else r.index\n",
        "        relevance_score = r[\"relevance_score\"] if isinstance(r, dict) else r.relevance_score\n",
        "        original_result = semantic_results[index]\n",
        "        final_results.append({\n",
        "            \"chunk\": original_result['metadata'],\n",
        "            \"score\": relevance_score\n",
        "        })\n",
        "\n",
        "    return final_results\n",
        "\n",
        "def evaluate_retrieval_rerank(queries: List[Dict[str, Any]], retrieval_function: Callable, db, k: int = 20) -> Dict[str, float]:\n",
        "    total_score = 0\n",
        "    total_queries = len(queries)\n",
        "\n",
        "    for query_item in tqdm(queries, desc=\"Evaluating retrieval\"):\n",
        "        query = query_item['query']\n",
        "        golden_chunk_uuids = query_item['golden_chunk_uuids']\n",
        "\n",
        "        golden_contents = []\n",
        "        for doc_uuid, chunk_index in golden_chunk_uuids:\n",
        "            golden_doc = next((doc for doc in query_item['golden_documents'] if doc['uuid'] == doc_uuid), None)\n",
        "            if golden_doc:\n",
        "                golden_chunk = next((chunk for chunk in golden_doc['chunks'] if chunk['index'] == chunk_index), None)\n",
        "                if golden_chunk:\n",
        "                    golden_contents.append(golden_chunk['content'].strip())\n",
        "\n",
        "        if not golden_contents:\n",
        "            print(f\"Warning: No golden contents found for query: {query}\")\n",
        "            continue\n",
        "\n",
        "        retrieved_docs = retrieval_function(query, db, k)\n",
        "\n",
        "        chunks_found = 0\n",
        "        for golden_content in golden_contents:\n",
        "            for doc in retrieved_docs[:k]:\n",
        "                retrieved_content = doc['chunk']['original_content'].strip()\n",
        "                if retrieved_content == golden_content:\n",
        "                    chunks_found += 1\n",
        "                    break\n",
        "\n",
        "        query_score = chunks_found / len(golden_contents)\n",
        "        total_score += query_score\n",
        "\n",
        "    average_score = total_score / total_queries\n",
        "    pass_at_n = average_score * 100\n",
        "    return {\n",
        "        \"pass_at_n\": pass_at_n,\n",
        "        \"average_score\": average_score,\n",
        "        \"total_queries\": total_queries\n",
        "    }\n",
        "\n",
        "# Note: this redefines evaluate_db_advanced from the Contextual BM25 section,\n",
        "# so the cell below evaluates reranked semantic search instead of hybrid search.\n",
        "def evaluate_db_advanced(db, original_jsonl_path, k):\n",
        "    original_data = load_jsonl(original_jsonl_path)\n",
        "\n",
        "    def retrieval_function(query, db, k):\n",
        "        return retrieve_rerank(query, db, k)\n",
        "\n",
        "    results = evaluate_retrieval_rerank(original_data, retrieval_function, db, k)\n",
        "    print(f\"Pass@{k}: {results['pass_at_n']:.2f}%\")\n",
        "    print(f\"Average Score: {results['average_score']:.2f}\")\n",
        "    print(f\"Total queries: {results['total_queries']}\")\n",
        "    return results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TAs4ELeHmj_m"
      },
      "outputs": [],
      "source": [
        "results5 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 5)\n",
        "results10 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 10)\n",
        "results20 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lokC8OVMmj_m"
      },
      "source": [
        "## Next Steps and Key Takeaways\n",
        "\n",
        "1) We demonstrated how to use Contextual Embeddings to improve retrieval performance, then delivered additional improvements with Contextual BM25 and reranking.\n",
        "\n",
        "2) This example used codebases, but these methods also apply to other data types such as internal company knowledge bases, financial & legal content, educational content, and much more.\n",
        "\n",
        "3) If you are an AWS user, you can get started with the Lambda function in `contextual-embeddings-lambda-function`, and if you're a GCP user you can spin up your own Cloud Run instance and follow a similar pattern!"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "py311",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.6"
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}